import numpy as np
from typing import List
from scipy.optimize import linear_sum_assignment as linear_assignment
import torch
import torch.nn.functional as F

def split_cluster_acc_v2(y_true, y_pred, mask, args):
    """
    Calculate clustering accuracy under the best one-to-one cluster/label matching.

    A single linear assignment (``scipy.optimize.linear_sum_assignment``) is
    computed on all data; accuracy is then reported on the whole set and on the
    old-class and new-class subsets using that same matching.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted cluster ids (non-negative integers), numpy.array with
            shape `(n_samples,)`
        mask: numpy.array with shape `(n_samples,)`, converted to bool; True marks
            instances of old classes, False instances of new classes
        args: namespace-like object or None. NOTE: when not None, its
            `train_session` attribute is overwritten with 'online' (kept for
            backward compatibility), so the new-class accuracy is always computed.

    # Return
        (total_acc, old_acc, new_acc), each in [0, 1]. old_acc / new_acc are 0.0
        when the corresponding subset contains no instances.
    """
    if args is not None:
        args.train_session = 'online'

    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)
    # Ensure boolean semantics: an int 0/1 mask would otherwise be used as
    # fancy indices and `~` would be a bitwise NOT.
    mask = np.asarray(mask, dtype=bool)

    old_classes_gt = set(y_true[mask])
    new_classes_gt = set(y_true[~mask])

    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    # w[p, t] = number of samples predicted as cluster p whose true label is t.
    w = np.zeros((D, D), dtype=int)
    np.add.at(w, (y_pred, y_true), 1)

    # Maximise the number of matched samples by minimising (max - count).
    row_ind, col_ind = linear_assignment(w.max() - w)
    ind_map = dict(zip(col_ind, row_ind))  # true label -> matched cluster id
    total_acc = w[row_ind, col_ind].sum() * 1.0 / y_pred.size

    def _subset_acc(classes):
        # Fraction of the given classes' samples that fall in their matched cluster.
        if not classes:
            return 0.0
        hits = sum(w[ind_map[c], c] for c in classes)
        total = sum(w[:, c].sum() for c in classes)
        return hits / total

    old_acc = _subset_acc(old_classes_gt)
    new_acc = _subset_acc(new_classes_gt)

    return total_acc, old_acc, new_acc


# Registry of accuracy functions selectable by name in log_accs_from_preds.
EVAL_FUNCS = dict(
    v2=split_cluster_acc_v2,
)

def log_accs_from_preds(y_true, y_pred, mask, eval_funcs: List[str], save_name: str, T: int=None, print_output=False, args=None):
    """
    Evaluate predictions with each named accuracy function and optionally print them.

    Args:
        y_true: ground-truth labels; cast to int.
        y_pred: predicted cluster ids; cast to int.
        mask: cast to bool; True marks old-class instances.
        eval_funcs: names looked up in EVAL_FUNCS (e.g. ['v2']); must be non-empty.
        save_name: prefix of the name shown in the printed log line.
        T: session index, only used in the printed log line.
        print_output: when True, print one line per evaluation function.
        args: forwarded unchanged to every evaluation function.

    Returns:
        (all_acc, old_acc, new_acc) produced by the FIRST entry of eval_funcs.

    Raises:
        ValueError: if eval_funcs is empty.
        KeyError: if a name is not registered in EVAL_FUNCS.
    """
    mask = mask.astype(bool)
    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)

    def _fmt(acc):
        # An evaluation function may report None (e.g. no new-class accuracy).
        return 'None' if acc is None else f'{acc:.4f}'

    to_return = None
    for i, f_name in enumerate(eval_funcs):

        acc_f = EVAL_FUNCS[f_name]
        all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask, args)
        log_name = f'{save_name}_{f_name}'

        if i == 0:
            to_return = (all_acc, old_acc, new_acc)

        if print_output:
            print_str = f'Session {T}, {log_name}: All {_fmt(all_acc)} | Old {_fmt(old_acc)} | New {_fmt(new_acc)}'
            print(print_str)

    if to_return is None:
        raise ValueError('eval_funcs must contain at least one evaluation function name')
    return to_return


def get_continual_gcd_data(datasetname, clf_to_use, stage):
    """
    Load the features for one stage of continual GCD (placeholder).

    The dataset-loading logic is not included in this file; an implementation
    must be supplied before this function can be used.

    Expected return value, in this order:
        (train_labeled_features, train_labeled_features_labels,
         unique_labels_len,
         train_unlabeled_features, train_unlabeled_features_labels,
         mask_soft_train_unlabeled_features,
         test_features, test_features_labels, mask_soft_test_features)

    Raises:
        NotImplementedError: always, until the loading logic is provided.
    """
    raise NotImplementedError(
        f'get_continual_gcd_data: dataset loading for {datasetname!r} '
        f'(stage {stage!r}) is not implemented'
    )



def origin_classifier(image_features, clf_to_use):
    """
    Assign every feature row to the classifier row with the highest dot-product score.

    Args:
        image_features: 2-D array, one feature vector per row.
        clf_to_use: 2-D array, one weight vector per row (same feature width).

    Returns:
        For each feature row, the index of the best-scoring row of clf_to_use.
    """
    return np.argmax(image_features @ clf_to_use.T, axis=1)

def select_clf_to_use(test_features, clf_to_use, unique_labels_len):
    """
    Keep the `unique_labels_len` classifier rows that are predicted most often.

    Each test feature is assigned to its best-scoring row via origin_classifier.
    Rows are then ranked by how many features they received, most frequent
    first. Rows that are never predicted are dropped, so the result may contain
    fewer than `unique_labels_len` rows.
    """
    predicted = origin_classifier(test_features, clf_to_use)
    labels, hits = np.unique(predicted, return_counts=True)

    # Reverse an ascending argsort to get a most-popular-first ranking.
    ranking = np.argsort(hits)[::-1]
    ranked_labels = labels[ranking]

    return clf_to_use[ranked_labels][:unique_labels_len]

def sinkhorn(M, tau_t=0.01, gamma=0, iter=20):
    """
    Balance a 2-D score matrix into a soft assignment via Sinkhorn-Knopp iterations.

    Args:
        M: 2-D score tensor; rows are balanced to sum to 1.
        tau_t: temperature of the initial row-wise softmax.
        gamma: if > 0, the target column marginals are the initial column mass
            raised to `gamma` and renormalised; otherwise they are uniform.
        iter: number of alternating column / row normalisation rounds.

    Returns:
        Tensor with the same shape as M whose rows each sum to 1.
    """
    n_rows, n_cols = M.shape
    plan = F.softmax(M / tau_t, dim=1)
    plan /= n_rows

    use_prior = gamma > 0
    if use_prior:
        col_prior = torch.sum(plan, dim=0, keepdim=True) ** gamma
        col_prior /= torch.sum(col_prior)

    for _ in range(iter):
        # Column step: scale columns to the target marginals.
        plan /= plan.sum(dim=0, keepdim=True)
        if not use_prior:
            plan /= n_cols
        else:
            plan *= col_prior

        # Row step: every row gets mass 1 / n_rows.
        plan /= plan.sum(dim=1, keepdim=True)
        plan /= n_rows

    plan *= n_rows
    return plan


def image_opt(feat, init_classifier, plabel, lr=10, iter=2000, tau_i=0.04, alpha=0.5):
    """
    Refine class prototypes by gradient descent on a soft-label cross-entropy.

    Args:
        feat: (n_samples, dim) feature tensor.
        init_classifier: (dim, n_classes) initial prototypes; not modified.
        plabel: (n_samples, n_classes) soft pseudo-labels; not modified. Rows whose
            largest probability exceeds `alpha` are hardened to one-hot targets
            (on an internal copy).
        lr: initial step size; halved whenever the gradient norm increases.
        iter: number of gradient steps.
        tau_i: softmax temperature applied to `feat @ classifier`.
        alpha: confidence threshold for hardening pseudo-labels.

    Returns:
        (dim, n_classes) prototypes with each column L2-normalised.
    """
    ins = feat.shape[0]
    # Work on a copy so the caller's pseudo-label tensor is left untouched.
    plabel = plabel.clone()
    val, idx = torch.max(plabel, dim=1)
    mask = val > alpha
    plabel[mask, :] = 0
    plabel[mask, idx[mask]] = 1

    # feat^T @ targets does not change across iterations; compute it once.
    base = feat.T @ plabel
    classifier = init_classifier.clone()
    pre_norm = float('inf')
    for _ in range(iter):
        prob = F.softmax(feat @ classifier / tau_i, dim=1)
        # Cross-entropy gradient w.r.t. the prototypes, up to the
        # 1 / (ins * tau_i) factor applied in the update below.
        grad = feat.T @ prob - base
        grad_norm = torch.norm(grad)
        if grad_norm > pre_norm:
            lr /= 2.
        pre_norm = grad_norm
        classifier -= (lr / (ins * tau_i)) * grad
        classifier = F.normalize(classifier, dim=0)
    return classifier


def train_inmap(text_classifier,
                train_labeled_features, train_labeled_features_labels,
                train_unlabeled_features, test_features, device=None):
    """
    Refine text-derived prototypes with labeled + unlabeled image features.

    Pseudo-labels come from Sinkhorn-balanced logits of the concatenated
    train features. The prototypes are then optimised with image_opt.

    Args:
        text_classifier: (n_classes, dim) initial prototypes.
        train_labeled_features: labeled train features, one per row.
        train_labeled_features_labels: accepted for interface compatibility;
            not used.
        train_unlabeled_features: unlabeled train features, one per row.
        test_features: accepted for interface compatibility; not used.
        device: torch device to compute on. Defaults to 'cuda' when available,
            otherwise 'cpu'.

    Returns:
        numpy array of refined prototypes, same (n_classes, dim) layout as
        text_classifier.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    labeled = torch.tensor(train_labeled_features, device=device)
    unlabeled = torch.tensor(train_unlabeled_features, device=device)
    # Transpose to (dim, n_classes) so that `features @ classifier` yields logits.
    classifier = torch.tensor(text_classifier, device=device).t()

    concat_logits = torch.cat((labeled @ classifier, unlabeled @ classifier), dim=0)
    plabels = sinkhorn(concat_logits, tau_t=0.01, gamma=0, iter=20)

    concat_features = torch.cat((labeled, unlabeled), dim=0)
    image_classifier = image_opt(concat_features, classifier, plabels, lr=10, iter=2000, tau_i=0.04, alpha=0.5)

    return image_classifier.t().cpu().numpy()

